library(tidyverse)
library(tidybayes)
library(bayesplot)
library(scales)
library(yardstick)
## For binary classification, the first factor level is assumed to be the event.
## Use the argument `event_level = "second"` to alter this as needed.
## 
## Attaching package: 'yardstick'
## The following object is masked from 'package:readr':
## 
##     spec
library(brms)
library(dbarts)
library(targets)
# Knit relative to the project root.
knitr::opts_knit$set(root.dir = here::here())

# Load the data splits and fitted models from the {targets} store into the
# current environment.
tar_load(c(
  bas_train,
  bas_validate,
  bas_train_prepped,
  bas_validate_prepped,
  bas_linear_model,
  bas_dbarts_model
))

Model Summary

# Likelihood family and link function of the fitted model.
bas_linear_model[["family"]]
## 
## Family: bernoulli 
## Link function: logit
# Model formula (main effects plus the pre-computed interaction columns).
bas_linear_model[["formula"]]
## mpx_pcr_pos ~ vacv_scar + fever_beforerash + rash_palm + rash_soles + lymphadenitis + rash_monomorphic + occ_lesion_un + occ_lesion_bi + sex_male + agecat2 + agecat3 + agecat4 + agecat5 + fever_beforerash_rash_palm + fever_beforerash_occ_lesion_un + rash_palm_rash_soles + rash_palm_lymphadenitis + rash_palm_rash_monomorphic + rash_palm_occ_lesion_un + rash_palm_occ_lesion_bi + rash_palm_sex_male + rash_palm_agecat2 + rash_palm_agecat3 + rash_palm_agecat4 + rash_soles_lymphadenitis + rash_soles_rash_monomorphic + rash_soles_occ_lesion_un + rash_soles_occ_lesion_bi + rash_soles_sex_male + rash_soles_agecat2 + rash_soles_agecat3 + rash_soles_agecat4 + lymphadenitis_rash_monomorphic + lymphadenitis_occ_lesion_bi + lymphadenitis_sex_male + lymphadenitis_agecat2 + lymphadenitis_agecat3 + lymphadenitis_agecat4 + rash_monomorphic_occ_lesion_bi + rash_monomorphic_sex_male + rash_monomorphic_agecat2 + rash_monomorphic_agecat3 + occ_lesion_bi_sex_male + occ_lesion_bi_agecat2 + occ_lesion_bi_agecat4 + sex_male_agecat2 + sex_male_agecat4
# Priors attached to each coefficient and to the intercept.
bas_linear_model |> prior_summary()
##                 prior     class                           coef group resp dpar
##  student_t(3, 0, 2.5)         b                                               
##  student_t(3, 0, 2.5)         b                        agecat2                
##  student_t(3, 0, 2.5)         b                        agecat3                
##  student_t(3, 0, 2.5)         b                        agecat4                
##  student_t(3, 0, 2.5)         b                        agecat5                
##  student_t(3, 0, 2.5)         b               fever_beforerash                
##  student_t(3, 0, 2.5)         b fever_beforerash_occ_lesion_un                
##  student_t(3, 0, 2.5)         b     fever_beforerash_rash_palm                
##  student_t(3, 0, 2.5)         b                  lymphadenitis                
##  student_t(3, 0, 2.5)         b          lymphadenitis_agecat2                
##  student_t(3, 0, 2.5)         b          lymphadenitis_agecat3                
##  student_t(3, 0, 2.5)         b          lymphadenitis_agecat4                
##  student_t(3, 0, 2.5)         b    lymphadenitis_occ_lesion_bi                
##  student_t(3, 0, 2.5)         b lymphadenitis_rash_monomorphic                
##  student_t(3, 0, 2.5)         b         lymphadenitis_sex_male                
##  student_t(3, 0, 2.5)         b                  occ_lesion_bi                
##  student_t(3, 0, 2.5)         b          occ_lesion_bi_agecat2                
##  student_t(3, 0, 2.5)         b          occ_lesion_bi_agecat4                
##  student_t(3, 0, 2.5)         b         occ_lesion_bi_sex_male                
##  student_t(3, 0, 2.5)         b                  occ_lesion_un                
##  student_t(3, 0, 2.5)         b               rash_monomorphic                
##  student_t(3, 0, 2.5)         b       rash_monomorphic_agecat2                
##  student_t(3, 0, 2.5)         b       rash_monomorphic_agecat3                
##  student_t(3, 0, 2.5)         b rash_monomorphic_occ_lesion_bi                
##  student_t(3, 0, 2.5)         b      rash_monomorphic_sex_male                
##  student_t(3, 0, 2.5)         b                      rash_palm                
##  student_t(3, 0, 2.5)         b              rash_palm_agecat2                
##  student_t(3, 0, 2.5)         b              rash_palm_agecat3                
##  student_t(3, 0, 2.5)         b              rash_palm_agecat4                
##  student_t(3, 0, 2.5)         b        rash_palm_lymphadenitis                
##  student_t(3, 0, 2.5)         b        rash_palm_occ_lesion_bi                
##  student_t(3, 0, 2.5)         b        rash_palm_occ_lesion_un                
##  student_t(3, 0, 2.5)         b     rash_palm_rash_monomorphic                
##  student_t(3, 0, 2.5)         b           rash_palm_rash_soles                
##  student_t(3, 0, 2.5)         b             rash_palm_sex_male                
##  student_t(3, 0, 2.5)         b                     rash_soles                
##  student_t(3, 0, 2.5)         b             rash_soles_agecat2                
##  student_t(3, 0, 2.5)         b             rash_soles_agecat3                
##  student_t(3, 0, 2.5)         b             rash_soles_agecat4                
##  student_t(3, 0, 2.5)         b       rash_soles_lymphadenitis                
##  student_t(3, 0, 2.5)         b       rash_soles_occ_lesion_bi                
##  student_t(3, 0, 2.5)         b       rash_soles_occ_lesion_un                
##  student_t(3, 0, 2.5)         b    rash_soles_rash_monomorphic                
##  student_t(3, 0, 2.5)         b            rash_soles_sex_male                
##  student_t(3, 0, 2.5)         b                       sex_male                
##  student_t(3, 0, 2.5)         b               sex_male_agecat2                
##  student_t(3, 0, 2.5)         b               sex_male_agecat4                
##  student_t(3, 0, 2.5)         b                      vacv_scar                
##  student_t(3, 0, 2.5) Intercept                                               
##  nlpar lb ub       source
##                      user
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##              (vectorized)
##                   default
# Posterior summary with 80% credible intervals. Stored in `bas_sum` because
# the coefficient plot later reads its interval columns for the x-axis range.
bas_sum <- bas_linear_model |> summary(prob = 0.8)
bas_sum
##  Family: bernoulli 
##   Links: mu = logit 
## Formula: mpx_pcr_pos ~ vacv_scar + fever_beforerash + rash_palm + rash_soles + lymphadenitis + rash_monomorphic + occ_lesion_un + occ_lesion_bi + sex_male + agecat2 + agecat3 + agecat4 + agecat5 + fever_beforerash_rash_palm + fever_beforerash_occ_lesion_un + rash_palm_rash_soles + rash_palm_lymphadenitis + rash_palm_rash_monomorphic + rash_palm_occ_lesion_un + rash_palm_occ_lesion_bi + rash_palm_sex_male + rash_palm_agecat2 + rash_palm_agecat3 + rash_palm_agecat4 + rash_soles_lymphadenitis + rash_soles_rash_monomorphic + rash_soles_occ_lesion_un + rash_soles_occ_lesion_bi + rash_soles_sex_male + rash_soles_agecat2 + rash_soles_agecat3 + rash_soles_agecat4 + lymphadenitis_rash_monomorphic + lymphadenitis_occ_lesion_bi + lymphadenitis_sex_male + lymphadenitis_agecat2 + lymphadenitis_agecat3 + lymphadenitis_agecat4 + rash_monomorphic_occ_lesion_bi + rash_monomorphic_sex_male + rash_monomorphic_agecat2 + rash_monomorphic_agecat3 + occ_lesion_bi_sex_male + occ_lesion_bi_agecat2 + occ_lesion_bi_agecat4 + sex_male_agecat2 + sex_male_agecat4 
##    Data: bas_train_prepped (Number of observations: 57) 
##   Draws: 4 chains, each with iter = 4000; warmup = 2000; thin = 1;
##          total post-warmup draws = 8000
## 
## Population-Level Effects: 
##                                Estimate Est.Error l-80% CI u-80% CI Rhat
## Intercept                         -7.47      3.73   -12.25    -2.90 1.00
## vacv_scar                         -4.69      4.49   -10.06    -0.32 1.00
## fever_beforerash                  -0.49      2.71    -3.77     2.70 1.00
## rash_palm                          1.24      2.76    -1.98     4.79 1.00
## rash_soles                         0.74      2.51    -2.27     3.84 1.00
## lymphadenitis                      0.70      2.43    -2.20     3.75 1.00
## rash_monomorphic                   2.35      2.53    -0.60     5.62 1.00
## occ_lesion_un                      1.02      2.71    -2.20     4.44 1.00
## occ_lesion_bi                     -1.42      2.61    -4.66     1.65 1.00
## sex_male                          -0.61      2.37    -3.56     2.24 1.00
## agecat2                            0.02      2.33    -2.73     2.81 1.00
## agecat3                           -1.24      2.56    -4.55     1.74 1.00
## agecat4                           -1.91      2.84    -5.61     1.25 1.00
## agecat5                            2.78      4.05    -1.17     7.43 1.00
## fever_beforerash_rash_palm         0.22      2.46    -2.80     3.25 1.00
## fever_beforerash_occ_lesion_un     1.61      3.01    -1.86     5.36 1.00
## rash_palm_rash_soles               3.22      2.92    -0.13     7.09 1.00
## rash_palm_lymphadenitis           -2.44      2.82    -6.08     0.79 1.00
## rash_palm_rash_monomorphic         1.55      2.44    -1.35     4.73 1.00
## rash_palm_occ_lesion_un            4.56      8.57    -0.85    10.84 1.00
## rash_palm_occ_lesion_bi           -0.56      2.54    -3.69     2.45 1.00
## rash_palm_sex_male                 4.50      3.16     0.78     8.66 1.00
## rash_palm_agecat2                 -1.83      2.69    -5.16     1.27 1.00
## rash_palm_agecat3                  0.43      2.41    -2.42     3.41 1.00
## rash_palm_agecat4                 -1.03      2.50    -4.16     1.91 1.00
## rash_soles_lymphadenitis          -1.34      2.37    -4.31     1.46 1.00
## rash_soles_rash_monomorphic       -0.57      2.34    -3.52     2.26 1.00
## rash_soles_occ_lesion_un           2.78      5.10    -1.65     7.84 1.00
## rash_soles_occ_lesion_bi           0.42      2.32    -2.37     3.30 1.00
## rash_soles_sex_male                2.67      2.77    -0.58     6.40 1.00
## rash_soles_agecat2                -0.74      2.64    -3.95     2.33 1.00
## rash_soles_agecat3                 0.06      2.26    -2.76     2.83 1.00
## rash_soles_agecat4                 0.47      2.49    -2.53     3.58 1.00
## lymphadenitis_rash_monomorphic     1.71      2.55    -1.30     5.05 1.00
## lymphadenitis_occ_lesion_bi        0.11      2.46    -2.89     3.11 1.00
## lymphadenitis_sex_male             0.18      2.28    -2.61     3.05 1.00
## lymphadenitis_agecat2             -0.06      2.24    -2.84     2.71 1.00
## lymphadenitis_agecat3             -1.07      2.58    -4.32     1.97 1.00
## lymphadenitis_agecat4             -0.91      2.55    -4.06     2.10 1.00
## rash_monomorphic_occ_lesion_bi    -0.43      2.51    -3.47     2.55 1.00
## rash_monomorphic_sex_male          0.17      2.14    -2.43     2.80 1.00
## rash_monomorphic_agecat2           2.25      2.64    -0.78     5.74 1.00
## rash_monomorphic_agecat3          -0.19      2.30    -3.05     2.54 1.00
## occ_lesion_bi_sex_male             1.05      2.50    -1.89     4.12 1.00
## occ_lesion_bi_agecat2              0.92      2.36    -1.98     3.85 1.00
## occ_lesion_bi_agecat4             -0.05      2.63    -3.21     3.17 1.00
## sex_male_agecat2                  -1.72      2.59    -4.98     1.21 1.00
## sex_male_agecat4                  -2.24      2.84    -5.95     1.03 1.00
##                                Bulk_ESS Tail_ESS
## Intercept                          7250     5460
## vacv_scar                          5245     2505
## fever_beforerash                   8911     4443
## rash_palm                          7970     4885
## rash_soles                         8222     5659
## lymphadenitis                      8475     5117
## rash_monomorphic                   6521     4724
## occ_lesion_un                      8651     4790
## occ_lesion_bi                      7660     4276
## sex_male                           7576     5092
## agecat2                            7806     4840
## agecat3                            7869     3740
## agecat4                            6713     3712
## agecat5                            6098     3231
## fever_beforerash_rash_palm         8335     5309
## fever_beforerash_occ_lesion_un     8909     4583
## rash_palm_rash_soles               5941     4699
## rash_palm_lymphadenitis            7081     4270
## rash_palm_rash_monomorphic         6676     4095
## rash_palm_occ_lesion_un            3781     1825
## rash_palm_occ_lesion_bi            9842     4844
## rash_palm_sex_male                 6052     5168
## rash_palm_agecat2                  6945     4141
## rash_palm_agecat3                  7769     4829
## rash_palm_agecat4                  8306     5611
## rash_soles_lymphadenitis           8583     5110
## rash_soles_rash_monomorphic        8466     4807
## rash_soles_occ_lesion_un           4471     1732
## rash_soles_occ_lesion_bi           8941     5687
## rash_soles_sex_male                6945     4886
## rash_soles_agecat2                 8042     5324
## rash_soles_agecat3                 9285     5945
## rash_soles_agecat4                 8214     4807
## lymphadenitis_rash_monomorphic     7977     4713
## lymphadenitis_occ_lesion_bi        8523     5217
## lymphadenitis_sex_male             8206     5667
## lymphadenitis_agecat2              8720     5830
## lymphadenitis_agecat3              9237     5024
## lymphadenitis_agecat4              8737     4514
## rash_monomorphic_occ_lesion_bi     8414     5208
## rash_monomorphic_sex_male          8947     6061
## rash_monomorphic_agecat2           7379     4914
## rash_monomorphic_agecat3           8431     5702
## occ_lesion_bi_sex_male             7647     4804
## occ_lesion_bi_agecat2              8425     4407
## occ_lesion_bi_agecat4              8411     4921
## sex_male_agecat2                   7535     4339
## sex_male_agecat4                   6789     4790
## 
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

MCMC Diagnostics

# brms' plot method for the fitted model.
bas_linear_model |> plot()

# Rank-overlay plot from bayesplot for the same fit.
bas_linear_model |> bayesplot::mcmc_rank_overlay()

Model Coefficients

# Long-format posterior draws of every population-level coefficient (`b_*`),
# excluding the intercept. A coefficient's label gets a trailing "*" when more
# than 95% or fewer than 5% of its draws are positive.
#
# `.variable` is deliberately left as character here: the plot below orders
# the coefficients itself, so converting to an ordered factor at this point
# would be discarded as soon as the "*" suffix is pasted on.
draws <- as_draws_df(bas_linear_model) |>
  gather_draws(`b_.*`, regex = TRUE) |>
  filter(.variable != "b_Intercept") |>
  group_by(.variable) |>
  mutate(sig = {
    # Share of positive draws for this coefficient, computed once per group.
    frac_pos <- sum(.value > 0) / n()
    frac_pos > 0.95 | frac_pos < 0.05
  }) |>
  ungroup() |>
  mutate(.variable = if_else(sig, paste0(.variable, "*"), .variable))
  

# Half-eye plot of each coefficient's posterior draws, ordered by posterior
# mean and filled by the sign of the value.
#
# The x-axis is zoomed to the span of the 80% credible-interval bounds
# (columns 3 and 4 of `bas_sum$fixed`). The zoom uses `coord_cartesian()`
# rather than scale limits: scale limits discard draws outside the range
# before the densities and intervals are computed (which also triggers a
# "Removed ... rows" warning), whereas `coord_cartesian()` only crops the view.
# `after_stat()` replaces the deprecated `stat()`.
# NOTE(review): re-knit to refresh the figure and drop the stale warning.
cplot <- ggplot(
  draws,
  aes(
    y = fct_reorder(.variable, .value, mean),
    x = .value,
    fill = after_stat(x > 0)
  )
) +
  stat_halfeye(p_limits = c(0.01, 0.99)) +
  coord_cartesian(
    xlim = c(min(bas_sum$fixed[[3]]), max(bas_sum$fixed[[4]]))
  )
cplot
## Warning: Removed 2949 rows containing missing values (stat_slabinterval).

# Per-coefficient posterior mean and share of positive draws, sorted from the
# smallest to the largest share; all rows are printed.
print(
  draws |>
    group_by(.variable) |>
    summarise(mean_value = mean(.value), frac_pos = sum(.value > 0) / n()) |>
    arrange(frac_pos),
  n = Inf
)
## # A tibble: 47 × 3
##    .variable                        mean_value frac_pos
##    <chr>                                 <dbl>    <dbl>
##  1 b_vacv_scar                         -4.69     0.0746
##  2 b_rash_palm_lymphadenitis           -2.44     0.176 
##  3 b_sex_male_agecat4                  -2.24     0.208 
##  4 b_rash_palm_agecat2                 -1.83     0.237 
##  5 b_agecat4                           -1.91     0.242 
##  6 b_sex_male_agecat2                  -1.72     0.245 
##  7 b_rash_soles_lymphadenitis          -1.34     0.287 
##  8 b_occ_lesion_bi                     -1.42     0.297 
##  9 b_agecat3                           -1.24     0.318 
## 10 b_rash_palm_agecat4                 -1.03     0.335 
## 11 b_lymphadenitis_agecat3             -1.07     0.336 
## 12 b_lymphadenitis_agecat4             -0.908    0.368 
## 13 b_rash_soles_agecat2                -0.742    0.388 
## 14 b_sex_male                          -0.615    0.407 
## 15 b_rash_soles_rash_monomorphic       -0.565    0.408 
## 16 b_rash_palm_occ_lesion_bi           -0.556    0.418 
## 17 b_rash_monomorphic_occ_lesion_bi    -0.428    0.426 
## 18 b_fever_beforerash                  -0.486    0.427 
## 19 b_rash_monomorphic_agecat3          -0.191    0.467 
## 20 b_occ_lesion_bi_agecat4             -0.0543   0.492 
## 21 b_lymphadenitis_agecat2             -0.0599   0.496 
## 22 b_agecat2                            0.0218   0.506 
## 23 b_lymphadenitis_occ_lesion_bi        0.109    0.515 
## 24 b_rash_soles_agecat3                 0.0634   0.515 
## 25 b_lymphadenitis_sex_male             0.177    0.529 
## 26 b_rash_monomorphic_sex_male          0.170    0.532 
## 27 b_fever_beforerash_rash_palm         0.219    0.541 
## 28 b_rash_soles_occ_lesion_bi           0.417    0.564 
## 29 b_rash_palm_agecat3                  0.426    0.565 
## 30 b_rash_soles_agecat4                 0.468    0.574 
## 31 b_lymphadenitis                      0.703    0.619 
## 32 b_rash_soles                         0.740    0.625 
## 33 b_occ_lesion_bi_agecat2              0.921    0.652 
## 34 b_occ_lesion_un                      1.02     0.653 
## 35 b_occ_lesion_bi_sex_male             1.05     0.662 
## 36 b_rash_palm                          1.24     0.673 
## 37 b_fever_beforerash_occ_lesion_un     1.61     0.708 
## 38 b_rash_palm_rash_monomorphic         1.55     0.739 
## 39 b_rash_soles_occ_lesion_un           2.78     0.742 
## 40 b_lymphadenitis_rash_monomorphic     1.71     0.750 
## 41 b_agecat5                            2.78     0.784 
## 42 b_rash_monomorphic_agecat2           2.25     0.818 
## 43 b_rash_palm_occ_lesion_un            4.56     0.831 
## 44 b_rash_monomorphic                   2.35     0.839 
## 45 b_rash_soles_sex_male                2.67     0.845 
## 46 b_rash_palm_rash_soles               3.22     0.888 
## 47 b_rash_palm_sex_male                 4.50     0.950

Prediction

In-Sample:

# In-sample predictions: attach the model's prediction summary to the training
# rows, classify at a 0.5 cut-off on `Estimate`, and cross-tabulate against
# the observed `mpx_pcr_pos` outcome.
predictions_in <- bind_cols(bas_train_prepped, predict(bas_linear_model)) |>
  mutate(
    pred = factor(as.integer(Estimate > 0.5), levels = c("0", "1")),
    true = factor(mpx_pcr_pos, levels = c("0", "1"))
  )
cmat_in <- predictions_in |> conf_mat(truth = true, estimate = pred)
cmat_in
##           Truth
## Prediction  0  1
##          0 41  3
##          1  1 12
# Treat "1" (the second factor level) as the event so that sens/spec/ppv/npv
# describe positive cases, consistent with the `event_level = "second"` used
# for the ROC metrics. By default yardstick takes the first level ("0") as the
# event, which swaps these metrics.
# NOTE(review): re-knit to refresh the metrics table printed below.
summary(cmat_in, event_level = "second")
## # A tibble: 13 × 3
##    .metric              .estimator .estimate
##    <chr>                <chr>          <dbl>
##  1 accuracy             binary         0.930
##  2 kap                  binary         0.811
##  3 sens                 binary         0.976
##  4 spec                 binary         0.8  
##  5 ppv                  binary         0.932
##  6 npv                  binary         0.923
##  7 mcc                  binary         0.815
##  8 j_index              binary         0.776
##  9 bal_accuracy         binary         0.888
## 10 detection_prevalence binary         0.772
## 11 precision            binary         0.932
## 12 recall               binary         0.976
## 13 f_meas               binary         0.953
# In-sample ROC curve data and AUC, with "1" (second level) as the event.
roc_dat_in <- predictions_in |>
  roc_curve(truth = true, Estimate, event_level = "second")
predictions_in |>
  roc_auc(truth = true, Estimate, event_level = "second")
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.986
# In-sample ROC curve with a dotted diagonal reference line.
roc_dat_in |>
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_path() +
  geom_abline(linetype = "dotted") +
  coord_equal() +
  theme_bw()

Out-Of-Sample:

# Out-of-sample predictions: same steps as in-sample, but predicting for the
# held-out validation rows via `newdata`.
predictions_out <- bind_cols(
  bas_validate_prepped,
  predict(bas_linear_model, newdata = bas_validate_prepped)
) |>
  mutate(
    pred = factor(as.integer(Estimate > 0.5), levels = c("0", "1")),
    true = factor(mpx_pcr_pos, levels = c("0", "1"))
  )
cmat_out <- predictions_out |> conf_mat(truth = true, estimate = pred)
cmat_out
##           Truth
## Prediction  0  1
##          0 12  4
##          1  3  2
# Treat "1" (the second factor level) as the event so that sens/spec/ppv/npv
# describe positive cases, consistent with the `event_level = "second"` used
# for the ROC metrics. By default yardstick takes the first level ("0") as the
# event, which swaps these metrics.
# NOTE(review): re-knit to refresh the metrics table printed below.
summary(cmat_out, event_level = "second")
## # A tibble: 13 × 3
##    .metric              .estimator .estimate
##    <chr>                <chr>          <dbl>
##  1 accuracy             binary         0.667
##  2 kap                  binary         0.140
##  3 sens                 binary         0.8  
##  4 spec                 binary         0.333
##  5 ppv                  binary         0.75 
##  6 npv                  binary         0.4  
##  7 mcc                  binary         0.141
##  8 j_index              binary         0.133
##  9 bal_accuracy         binary         0.567
## 10 detection_prevalence binary         0.762
## 11 precision            binary         0.75 
## 12 recall               binary         0.8  
## 13 f_meas               binary         0.774
# Out-of-sample ROC curve data and AUC, with "1" (second level) as the event.
roc_dat_out <- predictions_out |>
  roc_curve(truth = true, Estimate, event_level = "second")
predictions_out |>
  roc_auc(truth = true, Estimate, event_level = "second")
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.622
# Out-of-sample ROC curve with a dotted diagonal reference line.
roc_dat_out |>
  ggplot(aes(x = 1 - specificity, y = sensitivity)) +
  geom_path() +
  geom_abline(linetype = "dotted") +
  coord_equal() +
  theme_bw()